function [net,tr] = train_marq_three_lay(trainParam,P,T,VV,TT)
% TRAIN_MARQ_THREE_LAY  Marquardt training of a two- or three-layer network.
%
%   PARAMS = TRAIN_MARQ_THREE_LAY('pdefaults') returns a struct holding the
%   default training parameters (a second output, if requested, is empty).
%
%   [NET,TR] = TRAIN_MARQ_THREE_LAY(TRAINPARAM,P,T,VV,TT) trains a network
%   with the Marquardt algorithm.  The output layer is always linear
%   (nndpurelin).  With num_lay == 3 the layout is R-S1-S2-S3, where the
%   hidden layers use tf1 and tf2; with num_lay == 2 it is R-S1-S2 with a
%   single hidden layer using tf1.
%
%   Inputs
%     trainParam  struct with fields mu_initial, v, maxmu, max_fail,
%                 mingrad, max_epoch, err_goal, S1, S2, num_lay, tf1, df1,
%                 tf2, df2, show, time, ro (see the 'pdefaults' branch).
%                 ro weights a sum-of-squared-weights penalty that is added
%                 to the sum-squared error.
%     P           R x Q input matrix (R is taken from size(P,1)).
%     T           target matrix; its row count sets the output layer size.
%     VV          (optional) validation struct with fields P and T.  An
%                 optional logical field 'stop' (default true) controls
%                 whether validation failures may terminate training.
%     TT          (optional) test struct with fields P and T; only used to
%                 record tr.tperf.
%
%   Outputs
%     net   struct of weights and biases (W1,B1,W2,B2 and, for three
%           layers, W3,B3).  When validation data is supplied this is the
%           network with the lowest validation performance seen.
%     tr    training record: mer (objective), meu (mu), grad (gradient
%           norm), perf, and optionally vperf, tperf and epoch.
%
%   Weights are initialised with rand, so results depend on the state of
%   the random number generator.

% --- Request for default parameters --------------------------------------
if ischar(trainParam)
  switch (trainParam)
    case 'pdefaults',
      trainParam = [];
      trainParam.mu_initial = 0.01;
      trainParam.v = 10;
      trainParam.maxmu = 1e10;
      trainParam.max_fail = 5;
      trainParam.mingrad = 1e-8;
      trainParam.max_epoch = 100;
      trainParam.err_goal = 0;
      trainParam.S1 = 4;
      trainParam.S2 = 2;
      trainParam.num_lay = 3;
      trainParam.tf1 = @nndlogsig;
      trainParam.df1 = @mdeltalog;
      trainParam.tf2 = @nndlogsig;
      trainParam.df2 = @mdeltalog;
      trainParam.show = 25;
      trainParam.time = inf;
      trainParam.ro = 0;
      net = trainParam;
      tr = [];   % keeps a two-output call from failing on an unset output
    otherwise,
      error('Unrecognized code.')
  end
  return
end


% --- Unpack parameters ---------------------------------------------------
S1 = trainParam.S1;
S2 = trainParam.S2;
num_lay = trainParam.num_lay;
tf1 = trainParam.tf1;
df1 = trainParam.df1;
tf2 = trainParam.tf2;
df2 = trainParam.df2;
v = trainParam.v;
maxmu = trainParam.maxmu;
mu_initial = trainParam.mu_initial;
max_fail = trainParam.max_fail;
mingrad = trainParam.mingrad;
show = trainParam.show;
err_goal = trainParam.err_goal;
max_epoch = trainParam.max_epoch;
maxTime = trainParam.time;
ro = trainParam.ro;

% Validation / test data are optional trailing arguments.
doValidationStop = true;
if (nargin>=4)
  doValidation = ~isempty(VV);
  if doValidation && isfield(VV,'stop')
    doValidationStop = VV.stop;
  end
else
  doValidation = false;
end
doTest = (nargin>=5) && ~isempty(TT);
this = 'train_marq_three_lay';

% --- Network architecture ------------------------------------------------
% R is the input size.  The size of the last layer always comes from T; for
% a two-layer network that overrides trainParam.S2.
R = size(P,1);
if num_lay == 2,
    S2 = size(T,1);
else
    S3 = size(T,1);
end

% Random initial weights, uniform in [-0.5, 0.5].  The order of the rand
% calls is kept so that a seeded generator reproduces earlier runs.
W1 = (2*rand(S1,R)-1)*0.5;  B1 = (2*rand(S1,1)-1)*0.5;
W2 = (2*rand(S2,S1)-1)*0.5; B2 = (2*rand(S2,1)-1)*0.5;
if num_lay == 3,
    W3 = (2*rand(S3,S2)-1)*0.5; B3 = (2*rand(S3,1)-1)*0.5;
end

% Index boundaries of each parameter group inside the flat parameter
% vector [W1(:); B1; W2(:); B2; W3(:); B3] (same layout as getX / getW).
RS = S1*R; RS1 = RS+1; RSS = RS + S1; RSS1 = RSS + 1;
RSS2 = RSS + S1*S2; RSS3 = RSS2 + 1; RSS4 = RSS2 + S2;
if num_lay == 3,
    RSS5 = RSS4 + 1; RSS6 = RSS4 + S2*S3;
    RSS7 = RSS6 + 1; RSS8 = RSS6 + S3;
end

% Step matrices; only their shape matters, they are filled via dW(:) = ...
dW1 = zeros(S1,R);
dW2 = zeros(S2,S1);
if num_lay == 3,
    dW3 = zeros(S3,S2);
end

stop = '';
startTime = clock;
mu = mu_initial;

% --- Initial forward pass ------------------------------------------------
A1 = tf1(W1*P,B1);
if num_lay==3,
    A2 = tf2(W2*A1,B2);   % second hidden layer uses tf2, matching df2 below
    A3 = nndpurelin(W3*A2,B3);
    A = A3;
    ii = eye(RSS8);
    x = getX(W1,B1,W2,B2,W3,B3);
else
    A2 = nndpurelin(W2*A1,B2);
    A = A2;
    ii = eye(RSS4);
    x = getX(W1,B1,W2,B2);
end

E1 = T-A;

% Objective: sum-squared error plus ro times the squared weight norm.
f1 = (sum(sum(E1.*E1))) + ro*x'*x;
perf = f1;

if (doValidation)
  % Baseline validation performance, computed with the same transfer
  % functions as the network being trained.
  A1v = tf1(W1*VV.P,B1);
  if num_lay == 3,
      A2v = tf2(W2*A1v,B2);
      Av = nndpurelin(W3*A2v,B3);
  else
      Av = nndpurelin(W2*A1v,B2);
  end

  E1v = VV.T-Av;
  vperf = (sum(sum(E1v.*E1v))) + ro*x'*x;
  VV.perf = vperf;
  if num_lay==3,
      VV.net = getW(x,R,S1,S2,S3);
  else
      VV.net = getW(x,R,S1,S2);
  end
  VV.numFail = 0; tr.epoch = 0;
end


% --- MAIN LOOP -----------------------------------------------------------
for epoch = 1:max_epoch,

  % FIND JACOBIAN
  % NOTE: A1/A2/A3 are whatever the last forward pass left behind.  If the
  % inner loop of the previous epoch ended without accepting a step, they
  % belong to the last rejected trial; that epoch is followed by a stop on
  % the mu criterion below.
  if num_lay == 3,
      D3 = mdeltalin(A3);
      D2 = df2(A2,D3,W3);
      D1 = df1(A1,D2,W2);
      jac1 = learn_marq(kron(P,ones(1,S3)),D1);
      jac2 = learn_marq(A1,D2);
      jac3 = learn_marq(A2,D3);
      jac=[jac1,D1',jac2,D2',jac3,D3'];
  else
      D2 = mdeltalin(A2);
      D1 = df1(A1,D2,W2);
      jac1 = learn_marq(kron(P,ones(1,S2)),D1);
      jac2 = learn_marq(A1,D2);
      jac=[jac1,D1',jac2,D2'];
  end

  % CHECK THE MAGNITUDE OF THE GRADIENT
  E1=E1(:);
  je=jac'*E1;
  if num_lay == 3,
      w = getX(W1,B1,W2,B2,W3,B3);
  else
      w = getX(W1,B1,W2,B2);
  end
  grd = 2*je + 2*ro*w;
  normgX = norm(grd);

  % Save results
  tr.mer(epoch)=f1;
  tr.meu(epoch)=mu;
  tr.grad(epoch)=normgX;
  tr.perf(epoch)=perf;
  if (doValidation)
    tr.vperf(epoch) = VV.perf;
  end
  if (doTest)
    A1v = tf1(W1*TT.P,B1);
    if num_lay == 3,
        A2v = tf2(W2*A1v,B2);
        Av = nndpurelin(W3*A2v,B3);
    else
        Av = nndpurelin(W2*A1v,B2);
    end
    E1v = TT.T-Av;
    % Penalty uses the current accepted weights w, not a trial vector.
    tperf=(sum(sum(E1v.*E1v))) + ro*w'*w;
    tr.tperf(epoch) = tperf;
  end

  % Stopping Criteria (first match wins)
  currentTime = etime(clock,startTime);
  if (f1 <= err_goal)
    stop = 'Performance goal met.';
  elseif (epoch == max_epoch)
    stop = 'Maximum epoch reached, performance goal was not met.';
  elseif (currentTime > maxTime)
    stop = 'Maximum time elapsed, performance goal was not met.';
  elseif (normgX < mingrad)
    stop = 'Minimum gradient reached, performance goal was not met.';
  elseif (mu >= maxmu)
    % >= mirrors the inner loop condition (mu < maxmu): once mu reaches
    % maxmu no further step can be attempted.
    stop = 'Maximum MU reached, performance goal was not met.';
  elseif doValidation && doValidationStop && (VV.numFail > max_fail)
    stop = 'Validation stop.';
  end

  % Progress
  if isfinite(show) && (~rem(epoch,show) || ~isempty(stop)),
      if isfinite(max_epoch) fprintf('Epoch %g/%g',epoch, max_epoch); end
      if isfinite(maxTime) fprintf(', Time %4.1f%%',currentTime/maxTime*100); end
      if isfinite(err_goal) fprintf(', %s %g/%g','Sum-squared Error',f1,err_goal); end
      if isfinite(mingrad) fprintf(', Gradient %g/%g',normgX,mingrad); end
      fprintf('\n')
      if ~isempty(stop) fprintf('%s, %s\n\n',this,stop); end
  end

  % Stop when criteria indicate its time (net is assembled after the loop)
  if ~isempty(stop)
    break
  end

  % Gradient check (first epoch only): compare the analytic gradient grd
  % with a forward finite-difference estimate and print the squared error.
  if epoch==1,
    numParameters = length(w);
    A1v = tf1(W1*P,B1);
    if num_lay == 3,
        A2v = tf2(W2*A1v,B2);
        Av = nndpurelin(W3*A2v,B3);
    else
        Av = nndpurelin(W2*A1v,B2);
    end
    E1v = T-Av;
    basePerf = (sum(sum(E1v.*E1v))) + ro*w'*w;
    delta = 0.000001;          % finite-difference step
    wTrial = w;
    gX = zeros(numParameters,1);
    for j=1:numParameters,
      wTrial(j)=w(j)+delta;
      if num_lay == 3,
          net_temp = getW(wTrial,R,S1,S2,S3);
      else
          net_temp = getW(wTrial,R,S1,S2);
      end
      A1v1 = tf1(net_temp.W1*P,net_temp.B1);
      if num_lay == 3,
          A2v1 = tf2(net_temp.W2*A1v1,net_temp.B2);
          Av1 = nndpurelin(net_temp.W3*A2v1,net_temp.B3);
      else
          Av1 = nndpurelin(net_temp.W2*A1v1,net_temp.B2);
      end
      E1v1 = T-Av1;
      perf1 = sum(sum(E1v1.*E1v1)) + ro*wTrial'*wTrial;
      wTrial(j)=w(j);          % restore before perturbing the next entry
      gX(j) = (perf1-basePerf)/delta;
    end
    disp(['Sum square gradient error = ' num2str(sum(sum((gX-grd).^2)))])
  end
  % End of gradient checking


  % INNER LOOP, INCREASE mu UNTIL THE ERRORS ARE REDUCED
  jj=jac'*jac;
  while mu < maxmu,
    dw=-(jj+ii*(mu+ro))\(je + ro*w);

    % Scatter the flat step into per-layer updates and form trial weights.
    dW1(:)=dw(1:RS);
    dB1=dw(RS1:RSS);
    dW2(:)=dw(RSS1:RSS2);
    dB2=dw(RSS3:RSS4);
    W1n=W1+dW1;B1n=B1+dB1;W2n=W2+dW2;
    B2n=B2+dB2;
    A1 = tf1(W1n*P,B1n);
    if num_lay==3,
        dW3(:)=dw(RSS5:RSS6);
        dB3=dw(RSS7:RSS8);
        W3n=W3+dW3;B3n=B3+dB3;
        A2 = tf2(W2n*A1,B2n);
        A3 = nndpurelin(W3n*A2,B3n);
        A = A3;
        x = getX(W1n,B1n,W2n,B2n,W3n,B3n);
    else
        A2 = nndpurelin(W2n*A1,B2n);
        A = A2;
        x = getX(W1n,B1n,W2n,B2n);
    end

    E2 = T-A;
    f2 = sum(sum(E2.*E2)) + ro*x'*x;

    if (f2 < f1)
      % Accept the step and relax the damping.
      W1=W1n;B1=B1n;W2=W2n;B2=B2n;E1=E2;
      if num_lay==3,
          W3=W3n;B3=B3n;
      end
      f1=f2;
      w = x;
      mu = mu / v;
      if (mu < 1e-20)
        mu = 1e-20;
      end
      break   % Must be after the IF
    end
    % Reject the step and increase the damping.
    mu = mu * v;
  end

  % Record the objective of the currently accepted weights, whether or not
  % this epoch accepted a step.
  perf = f1;

  if (doValidation)
    A1v = tf1(W1*VV.P,B1);
    if num_lay == 3,
        A2v = tf2(W2*A1v,B2);
        Av = nndpurelin(W3*A2v,B3);
    else
        Av = nndpurelin(W2*A1v,B2);
    end

    E1v = VV.T-Av;
    vperf = sum(sum(E1v.*E1v)) + ro*w'*w;
    if (vperf < VV.perf)
      % New best validation result: remember these weights.
      VV.perf = vperf;
      if num_lay == 3,
          VV.net = getW(w,R,S1,S2,S3);
      else
          VV.net = getW(w,R,S1,S2);
      end
      VV.numFail = 0; tr.epoch = epoch+1;
    elseif (vperf > VV.perf)
      VV.numFail = VV.numFail + 1;
    end
  end

end

% truncate vectors to the number of epochs actually run
tr.mer=tr.mer(1:epoch);
tr.meu=tr.meu(1:epoch);
tr.grad=tr.grad(1:epoch);
tr.perf=tr.perf(1:epoch);
if (doValidation)
  tr.vperf=tr.vperf(1:epoch);
end
if (doTest)
  tr.tperf=tr.tperf(1:epoch);
end

% Save Results
%=================
if (doValidation)
  net = VV.net;
else
  net.W1=W1; net.B1=B1; net.W2=W2; net.B2=B2;
  if num_lay == 3,
      % The third layer is part of the trained network as well.
      net.W3=W3; net.B3=B3;
  end
end

end


%========================
function x = getX(W1,B1,W2,B2,W3,B3)
% GETX  Pack layer weights and biases into a single column vector.
%
%   x = getX(W1,B1,W2,B2) returns [W1(:); B1(:); W2(:); B2(:)].
%   x = getX(W1,B1,W2,B2,W3,B3) additionally appends [W3(:); B3(:)].
%
%   Matrices are flattened column by column, which is the layout getW
%   expects when it unpacks the vector again.  The third layer is only
%   included when all six arguments are supplied.

% Reject shapes that would yield a vector getW cannot unpack consistently.
if size(W2,2) ~= size(W1,1) || numel(B1) ~= size(W1,1) || numel(B2) ~= size(W2,1)
    error('getX:sizeMismatch', ...
          'Layer 1/2 weight and bias sizes are inconsistent.');
end

x = [W1(:); B1(:); W2(:); B2(:)];

if nargin == 6,
    if numel(B3) ~= size(W3,1)
        error('getX:sizeMismatch', ...
              'Layer 3 weight and bias sizes are inconsistent.');
    end
    x = [x; W3(:); B3(:)];
end

end

%===============================
function [net] = getW(x,R,S1,S2,S3)
% GETW  Unpack a flat parameter vector into per-layer weights and biases.
%
%   net = getW(x,R,S1,S2) reads x in the order [W1(:); B1; W2(:); B2] and
%   returns a struct with W1 (S1-by-R), B1, W2 (S2-by-S1) and B2.
%   net = getW(x,R,S1,S2,S3) also reads [W3(:); B3] and adds W3 (S3-by-S2)
%   and B3.  Bias fields are plain slices of x and keep its orientation.

% Element count of each group, then cumulative start/stop positions.
counts = [S1*R, S1, S1*S2, S2];
if nargin > 4,
    counts = [counts, S2*S3, S3];
end
last = cumsum(counts);
first = last - counts + 1;

net.W1 = reshape(x(first(1):last(1)), S1, R);
net.W2 = reshape(x(first(3):last(3)), S2, S1);
net.B1 = x(first(2):last(2));
net.B2 = x(first(4):last(4));

if nargin > 4,
    net.W3 = reshape(x(first(5):last(5)), S3, S2);
    net.B3 = x(first(6):last(6));
end

end



